package ch4

import (
	"database/sql"
	"fmt"

	_ "github.com/go-sql-driver/mysql"
)

// db is the package-level connection pool used by every helper in this file.
// It is assigned by initDB.
var db *sql.DB

// initDB opens the MySQL connection pool and verifies it with a ping.
//
// On success the package-level db is ready for use and nil is returned.
// If sql.Open fails, its error is returned. If the ping fails, the pool that
// sql.Open created is closed again (so it does not leak) and the ping error
// is returned.
func initDB() (err error) {
	// Assign with "=" rather than ":=": ":=" would declare a new local db that
	// shadows the package-level variable, leaving the global one nil.
	// NOTE(review): the DSN (including the root password) is hard-coded;
	// consider loading it from configuration or the environment.
	db, err = sql.Open("mysql", "root:a123456!@tcp(localhost:3306)/ch4")
	if err != nil {
		return err
	}
	// sql.Open may only validate its arguments; Ping is what actually
	// establishes a connection to the server.
	if err = db.Ping(); err != nil {
		_ = db.Close() // best-effort cleanup; the ping error is the one worth reporting
		return err
	}
	return nil
}

// queryRow looks up the user whose uid equals id and prints it to stdout.
// If the query or scan fails (including when no row matches, which Scan
// reports as sql.ErrNoRows), only the error is printed.
func queryRow(id int) {
	const query = "select uid, name, phone from `USER` where uid = ?"
	var u User
	if err := db.QueryRow(query, id).Scan(&u.Uid, &u.Name, &u.Phone); err != nil {
		fmt.Printf("scan failed, err:%v\n", err)
		// Stop here: u holds only zero values and must not be printed as a result.
		return
	}
	fmt.Printf("uid:%d, name:%s, phone:%s\n", u.Uid, u.Name, u.Phone)
}

// queryManyRows prints every user whose uid is greater than or equal to id.
// A scan error aborts the listing. An error raised while iterating the result
// set is reported after the loop.
func queryManyRows(id int) {
	const query = "select uid, name, phone from `USER` where uid >= ?"
	rows, err := db.Query(query, id)
	if err != nil {
		fmt.Printf("query failed, err:%v\n", err)
		return
	}
	defer rows.Close()
	for rows.Next() {
		var u User
		if err := rows.Scan(&u.Uid, &u.Name, &u.Phone); err != nil {
			fmt.Printf("scan failed, err:%v\n", err)
			return
		}
		fmt.Printf("uid:%d, name:%s, phone:%s\n", u.Uid, u.Name, u.Phone)
	}
	// rows.Next returns false both at the end of the result set and on error;
	// only rows.Err distinguishes the two.
	if err := rows.Err(); err != nil {
		fmt.Printf("iterate failed, err:%v\n", err)
	}
}

// insertRows inserts each user's name and phone into USER, in order, and
// prints the generated uid of every successful insert. It stops at the first
// insert that fails. If only retrieving the generated uid fails, that error is
// printed and the remaining users are still inserted.
func insertRows(users []User) {
	const query = "insert into USER(name, phone) values (?,?)"
	for _, user := range users {
		ret, err := db.Exec(query, user.Name, user.Phone)
		if err != nil {
			fmt.Printf("insert failed, err:%v\n", err)
			return
		}
		uid, err := ret.LastInsertId()
		if err != nil {
			// The row was inserted; only the id lookup failed, so keep going.
			fmt.Printf("get last insert id failed, err:%v\n", err)
			continue
		}
		fmt.Printf("add success, uid is%d\n", uid)
	}
}

// updateRows sets the phone of the USER row whose uid is user.Uid to
// user.Phone and prints how many rows the driver reports as affected.
func updateRows(user User) {
	const query = "update USER set phone = ? where uid = ?"
	ret, err := db.Exec(query, user.Phone, user.Uid)
	if err != nil {
		fmt.Printf("update failed, err:%v\n", err)
		return
	}
	n, err := ret.RowsAffected()
	if err != nil {
		fmt.Printf("get affected rows failed, err:%v\n", err)
		return
	}
	fmt.Printf("update success, affected rows is %d\n", n)
}

// prepareQuery fetches the user with the given uid through an explicitly
// prepared statement and prints it. If preparing or scanning fails, only the
// corresponding error is printed.
func prepareQuery(id int) {
	const query = "select uid, name, phone from `USER` where uid = ?"
	stmt, err := db.Prepare(query)
	if err != nil {
		fmt.Printf("prepare failed, err:%v\n", err)
		return
	}
	defer stmt.Close()

	var u User
	// Report the Scan error itself; the Prepare error is nil at this point.
	if err := stmt.QueryRow(id).Scan(&u.Uid, &u.Name, &u.Phone); err != nil {
		fmt.Printf("scan failed, err:%v\n", err)
		return
	}
	fmt.Printf("uid:%d, name:%s, phone:%s\n", u.Uid, u.Name, u.Phone)
}

// insertWithTrans inserts one user (name and phone) inside a transaction and
// prints the generated uid once the transaction has committed. A failure while
// preparing, executing or committing is printed and the transaction is rolled
// back.
func insertWithTrans(user User) {
	tx, err := db.Begin()
	if err != nil {
		// db.Begin returns a nil *sql.Tx on error, so there is nothing to roll back.
		fmt.Printf("open transaction err%v\n", err)
		return
	}
	// Rollback after a successful Commit is a no-op (it returns sql.ErrTxDone),
	// so deferring it guarantees cleanup on every early-return path.
	defer tx.Rollback()

	const query = "insert into USER(name, phone) values (?,?)"
	stmt, err := tx.Prepare(query)
	if err != nil {
		fmt.Printf("prepare failed, err:%v\n", err)
		return
	}
	// Deferred after tx.Rollback, so it runs first (LIFO): the statement is
	// closed before the transaction ends.
	defer stmt.Close()

	ret, err := stmt.Exec(user.Name, user.Phone)
	if err != nil {
		fmt.Printf("add error rollback%v\n", err)
		return
	}
	if err := tx.Commit(); err != nil {
		fmt.Printf("commit failed, err:%v\n", err)
		return
	}
	uid, err := ret.LastInsertId()
	if err != nil {
		fmt.Printf("get last insert id failed, err:%v\n", err)
		return
	}
	fmt.Printf("add success, uid is%d\n", uid)
}
